{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe4cb25e4fb45586",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Tutorial: Running Solution Optimization\n",
    "\n",
    "![TextGrad](https://github.com/vinid/data/blob/master/logo_full.png?raw=true)\n",
    "\n",
    "An autograd engine -- for textual gradients!\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zou-group/TextGrad/blob/main/examples/notebooks/Prompt-Optimization.ipynb)\n",
    "[![GitHub license](https://img.shields.io/badge/License-MIT-blue.svg)](https://lbesson.mit-license.org/)\n",
    "[![Arxiv](https://img.shields.io/badge/arXiv-2406.07496-B31B1B.svg)](https://arxiv.org/abs/2406.07496)\n",
    "[![Documentation Status](https://readthedocs.org/projects/textgrad/badge/?version=latest)](https://textgrad.readthedocs.io/en/latest/?badge=latest)\n",
    "[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/textgrad)](https://pypi.org/project/textgrad/)\n",
    "[![PyPI](https://img.shields.io/pypi/v/textgrad)](https://pypi.org/project/textgrad/)\n",
    "\n",
    "**Objectives:**\n",
    "\n",
    "* In this tutorial, we will build a solution optimization pipeline with TextGrad: an LLM critiques a flawed solution to a quadratic equation, and that critique is used as a textual gradient to rewrite the solution.\n",
    "\n",
    "**Requirements:**\n",
    "\n",
    "* You need an OpenAI API key to run this tutorial. Set it in the `OPENAI_API_KEY` environment variable.\n"
   ]
  },
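  {
   "cell_type": "markdown",
   "id": "api-key-check-note",
   "metadata": {},
   "source": [
    "The requirement above can be checked before any model call is made. The cell below is a minimal sketch and not part of TextGrad: it only tests whether `OPENAI_API_KEY` is visible to this kernel. The name `api_key_is_set` is a hypothetical helper variable introduced for illustration.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "api-key-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Hypothetical pre-flight check (not a TextGrad API): is the key visible to this kernel?\n",
    "api_key_is_set = bool(os.environ.get(\"OPENAI_API_KEY\"))\n",
    "\n",
    "if api_key_is_set:\n",
    "    print(\"OPENAI_API_KEY is set.\")\n",
    "else:\n",
    "    print(\"OPENAI_API_KEY is not set; export it before running the cells below.\")"
   ]
  },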
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f6e021565d0c914",
   "metadata": {
    "collapsed": false
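   },
   "outputs": [],
   "source": [
    "!pip install textgrad # you might need to restart the kernel after installing textgrad\n",
    "\n",
    "import textgrad as tg\n",
    "\n",
    "# Engine (LLM) used to generate textual gradients during the backward pass\n",
    "tg.set_backward_engine(tg.get_engine(\"gpt-4o\"))\n",
    "\n",
    "# Starting solution with errors for the optimizer to fix: the discriminant uses + instead of -,\n",
    "# and the final answers omit the division by 2a\n",
    "initial_solution = \"\"\"To solve the equation 3x^2 - 7x + 2 = 0, we use the quadratic formula:\n",
    "x = (-b ± √(b^2 - 4ac)) / 2a\n",
    "a = 3, b = -7, c = 2\n",
    "x = (7 ± √((-7)^2 + 4(3)(2))) / 6\n",
    "x = (7 ± √73) / 6\n",
    "The solutions are:\n",
    "x1 = (7 + √73)\n",
    "x2 = (7 - √73)\"\"\"\n",
    "\n",
    "# The variable being optimized; requires_grad=True lets it receive feedback\n",
    "solution = tg.Variable(initial_solution,\n",
    "                       requires_grad=True,\n",
    "                       role_description=\"solution to the math question\")\n",
    "\n",
    "# Instructions for the evaluator LLM; kept fixed, so requires_grad=False\n",
    "loss_system_prompt = tg.Variable(\"\"\"You will evaluate a solution to a math question. \n",
    "Do not attempt to solve it yourself, do not give a solution, only identify errors. Be super concise.\"\"\",\n",
    "                                 requires_grad=False,\n",
    "                                 role_description=\"system prompt\")\n",
    "\n",
    "# The loss is an LLM evaluation of the solution, guided by the system prompt\n",
    "loss_fn = tg.TextLoss(loss_system_prompt)\n",
    "# TGD (Textual Gradient Descent) updates the solution using the feedback\n",
    "optimizer = tg.TGD([solution])"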
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "169c07c1-a01d-4309-9fd4-a6f9fb52daf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable(value=Errors:\n",
       "1. Incorrect sign in the discriminant calculation: it should be b^2 - 4ac, not b^2 + 4ac.\n",
       "2. Incorrect simplification of the quadratic formula: the denominator should be 2a, not 6.\n",
       "3. Final solutions are missing the division by 2a., role=response from the language model, grads=)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Forward pass: the loss is the LLM's critique of the current solution\n",
    "loss = loss_fn(solution)\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a2fce8c1-6838-4aaf-b830-de34effe44db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "To solve the equation 3x^2 - 7x + 2 = 0, we use the quadratic formula:\n",
      "x = (-b ± √(b^2 - 4ac)) / 2a\n",
      "\n",
      "Given:\n",
      "a = 3, b = -7, c = 2\n",
      "\n",
      "Substitute the values into the formula:\n",
      "x = (7 ± √((-7)^2 - 4(3)(2))) / 6\n",
      "x = (7 ± √(49 - 24)) / 6\n",
      "x = (7 ± √25) / 6\n",
      "x = (7 ± 5) / 6\n",
      "\n",
      "The solutions are:\n",
      "x1 = (7 + 5) / 6 = 12 / 6 = 2\n",
      "x2 = (7 - 5) / 6 = 2 / 6 = 1/3\n"
     ]
    }
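   ],
   "source": [
    "# Backward pass: turn the critique into textual gradients (feedback) on the solution\n",
    "loss.backward()\n",
    "# Update step: rewrite the solution in light of that feedback\n",
    "optimizer.step()\n",
    "print(solution.value)"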
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
